#!/usr/bin/env python

#####################################################################
#
# intfstat is a tool for summarizing L3 network statistics.
#
#####################################################################

import argparse
import cPickle as pickle
import datetime
import sys
import os
import swsssdk
import sys
import time

# Mock the redis DB connector for unit test purposes.
#
# The flag is read with os.environ.get() instead of wrapping the whole block
# in try/except KeyError: the broad try also swallowed any KeyError raised
# while importing the mock module itself, hiding real failures in the tests.
if os.environ.get("UTILITIES_UNIT_TESTING") == "1":
    modules_path = os.path.join(os.path.dirname(__file__), "..")
    test_path = os.path.join(modules_path, "sonic-utilities-tests")
    # Both paths are pushed to the front of sys.path; test_path is inserted
    # last, so it is searched first.
    sys.path.insert(0, modules_path)
    sys.path.insert(0, test_path)
    # Imported for its side effects only (presumably patches the DB
    # connector used below -- verify against mock_tables).
    import mock_tables.dbconnector

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate
from utilities_common.netstat import ns_diff, ns_brate, ns_prate, table_as_json, STATUS_NA

# Names of the per-interface counters read from the counters DB.
# The order matters: position N here fills field N of NStats below.
counter_names = (
    'SAI_ROUTER_INTERFACE_STAT_IN_OCTETS',
    'SAI_ROUTER_INTERFACE_STAT_IN_PACKETS',
    'SAI_ROUTER_INTERFACE_STAT_OUT_OCTETS',
    'SAI_ROUTER_INTERFACE_STAT_OUT_PACKETS',
    'SAI_ROUTER_INTERFACE_STAT_IN_ERROR_OCTETS',
    'SAI_ROUTER_INTERFACE_STAT_IN_ERROR_PACKETS',
    'SAI_ROUTER_INTERFACE_STAT_OUT_ERROR_OCTETS',
    'SAI_ROUTER_INTERFACE_STAT_OUT_ERROR_PACKETS',
)

# One record of counters for a single interface
# (b = bytes, p = packets, ok = good traffic, err = errors).
NStats = namedtuple("NStats", [
    "rx_b_ok", "rx_p_ok", "tx_b_ok", "tx_p_ok",
    "rx_b_err", "rx_p_err", "tx_b_err", "tx_p_err",
])

# Column titles of the summary table.
header = [
    'IFACE',
    'RX_OK', 'RX_BPS', 'RX_PPS', 'RX_ERR',
    'TX_OK', 'TX_BPS', 'TX_PPS', 'TX_ERR',
]


# Keys / key prefixes used in the counters DB.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_RIF_NAME_MAP = "COUNTERS_RIF_NAME_MAP"
COUNTERS_RIF_TYPE_MAP = "COUNTERS_RIF_TYPE_MAP"

# Key prefix and status values used for interface state lookups.
INTERFACE_TABLE_PREFIX = "PORT_TABLE:"
INTF_STATUS_VALUE_UP = 'UP'
INTF_STATUS_VALUE_DOWN = 'DOWN'

# One-letter interface state codes.
INTF_STATE_UP = 'U'
INTF_STATE_DOWN = 'D'
INTF_STATE_DISABLED = 'X'

class Intfstat(object):
    """
        Reads router-interface (RIF) counters from the counters DB and
        prints them, either as absolute values or as a difference against
        an older snapshot.
    """

    # Field names holding the admin / operational status of an interface
    # entry. NOTE(review): these names were not defined anywhere in the
    # file; the values below are presumably the PORT_TABLE field names in
    # APPL_DB -- confirm against the DB schema.
    INTF_ADMIN_STATUS_FIELD = "admin_status"
    INTF_OPER_STATUS_FIELD = "oper_status"

    def __init__(self):
        self.db = swsssdk.SonicV2Connector(host='127.0.0.1')
        self.db.connect(self.db.COUNTERS_DB)
        self.db.connect(self.db.APPL_DB)

    def get_cnstat(self, rif=None):
        """
            Get the counters info from database.

            rif: name of a single interface to read; when empty/None all
                 interfaces listed in COUNTERS_RIF_NAME_MAP are read.

            Returns an OrderedDict whose first entry 'time' is the snapshot
            time, followed by one NStats entry per interface name.
            Exits with status 1 if the name map is missing from the DB and
            with status 2 if the requested interface is not in the map.
        """
        def get_counters(table_id):
            """
                Get the counters from specific table.
                Counters that are missing (or empty) stay STATUS_NA.
            """
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            # One slot per counter name; NStats has exactly these fields.
            fields = [STATUS_NA] * len(counter_names)
            for pos, counter_name in enumerate(counter_names):
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                if counter_data:
                    fields[pos] = str(counter_data)
            return NStats._make(fields)

        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()

        # Get the info from database
        counter_rif_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_RIF_NAME_MAP)

        if counter_rif_name_map is None:
            print("No %s in the DB!" % COUNTERS_RIF_NAME_MAP)
            sys.exit(1)

        if rif and rif not in counter_rif_name_map:
            print("Interface %s missing from %s! Make sure it exists" % (rif, COUNTERS_RIF_NAME_MAP))
            sys.exit(2)

        if rif:
            cnstat_dict[rif] = get_counters(counter_rif_name_map[rif])
            return cnstat_dict

        for rif_name in natsorted(counter_rif_name_map):
            cnstat_dict[rif_name] = get_counters(counter_rif_name_map[rif_name])
        return cnstat_dict

    def get_intf_state(self, port_name):
        """
            Get the port state

            Returns INTF_STATE_DISABLED when the admin status is DOWN,
            INTF_STATE_UP / INTF_STATE_DOWN when the admin status is UP
            (according to the operational status), and STATUS_NA when
            either status is missing or has an unexpected value.
        """
        full_table_id = INTERFACE_TABLE_PREFIX + port_name
        admin_state = self.db.get(self.db.APPL_DB, full_table_id, self.INTF_ADMIN_STATUS_FIELD)
        oper_state = self.db.get(self.db.APPL_DB, full_table_id, self.INTF_OPER_STATUS_FIELD)
        if admin_state is None or oper_state is None:
            return STATUS_NA

        admin_state = admin_state.upper()
        oper_state = oper_state.upper()
        if admin_state == INTF_STATUS_VALUE_DOWN:
            return INTF_STATE_DISABLED
        if admin_state == INTF_STATUS_VALUE_UP:
            if oper_state == INTF_STATUS_VALUE_UP:
                return INTF_STATE_UP
            if oper_state == INTF_STATUS_VALUE_DOWN:
                return INTF_STATE_DOWN
        return STATUS_NA

    def cnstat_print(self, cnstat_dict, use_json):
        """
            Print the cnstat.

            Only packet counters are shown; the rate columns are STATUS_NA
            because a single snapshot has no time base.
        """
        table = []

        for key, data in cnstat_dict.items():
            if key == 'time':
                continue

            table.append((key, data.rx_p_ok, STATUS_NA, STATUS_NA, data.rx_p_err,
                          data.tx_p_ok, STATUS_NA, STATUS_NA, data.tx_p_err))

        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, use_json):
        """
            Print the difference between two cnstat results.

            Interfaces missing from the old snapshot are printed with their
            absolute counters and STATUS_NA rates.
        """
        # The elapsed time is computed up front so it never depends on
        # 'time' being visited before the counter entries during iteration.
        new_time = cnstat_new_dict.get('time')
        old_time = cnstat_old_dict.get('time')
        time_gap = None
        if new_time is not None and old_time is not None:
            time_gap = (new_time - old_time).total_seconds()

        def rate(rate_fn, new_value, old_value):
            # Without a time base no rate can be computed.
            if time_gap is None:
                return STATUS_NA
            return rate_fn(new_value, old_value, time_gap)

        table = []

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                continue

            old_cntr = cnstat_old_dict.get(key)
            if old_cntr is not None:
                table.append((key,
                              ns_diff(cntr.rx_p_ok, old_cntr.rx_p_ok),
                              rate(ns_brate, cntr.rx_b_ok, old_cntr.rx_b_ok),
                              rate(ns_prate, cntr.rx_p_ok, old_cntr.rx_p_ok),
                              ns_diff(cntr.rx_p_err, old_cntr.rx_p_err),
                              ns_diff(cntr.tx_p_ok, old_cntr.tx_p_ok),
                              rate(ns_brate, cntr.tx_b_ok, old_cntr.tx_b_ok),
                              rate(ns_prate, cntr.tx_p_ok, old_cntr.tx_p_ok),
                              ns_diff(cntr.tx_p_err, old_cntr.tx_p_err)))
            else:
                table.append((key,
                              cntr.rx_p_ok,
                              STATUS_NA,
                              STATUS_NA,
                              cntr.rx_p_err,
                              cntr.tx_p_ok,
                              STATUS_NA,
                              STATUS_NA,
                              cntr.tx_p_err))

        if use_json:
            print(table_as_json(table, header))
        else:
            print(tabulate(table, header, tablefmt='simple', stralign='right'))

    def cnstat_single_interface(self, rif, cnstat_new_dict, cnstat_old_dict):
        """
            Print the detailed counters of a single interface.

            When cnstat_old_dict holds an entry for the interface the
            difference to it is printed, otherwise (no old snapshot, or the
            interface is missing from it) the absolute counters are printed.
        """
        title = rif + '\n' + '-' * len(rif)
        # Written line by line so the exact output (including the trailing
        # blanks after the first two RX lines) does not depend on source
        # whitespace.
        body = ("\n"
                "        RX:\n"
                "        %10s packets \n"
                "        %10s bytes \n"
                "        %10s error packets\n"
                "        %10s error bytes\n"
                "        TX:\n"
                "        %10s packets\n"
                "        %10s bytes\n"
                "        %10s error packets\n"
                "        %10s error bytes")

        cntr = cnstat_new_dict.get(rif)
        old_cntr = cnstat_old_dict.get(rif) if cnstat_old_dict else None

        if old_cntr is not None:
            values = (ns_diff(cntr.rx_p_ok, old_cntr.rx_p_ok),
                      ns_diff(cntr.rx_b_ok, old_cntr.rx_b_ok),
                      ns_diff(cntr.rx_p_err, old_cntr.rx_p_err),
                      ns_diff(cntr.rx_b_err, old_cntr.rx_b_err),
                      ns_diff(cntr.tx_p_ok, old_cntr.tx_p_ok),
                      ns_diff(cntr.tx_b_ok, old_cntr.tx_b_ok),
                      ns_diff(cntr.tx_p_err, old_cntr.tx_p_err),
                      ns_diff(cntr.tx_b_err, old_cntr.tx_b_err))
        else:
            # Also taken when an old snapshot exists but lacks this
            # interface; previously the unformatted template was printed.
            values = (cntr.rx_p_ok, cntr.rx_b_ok, cntr.rx_p_err, cntr.rx_b_err,
                      cntr.tx_p_ok, cntr.tx_b_ok, cntr.tx_p_err, cntr.tx_b_err)

        print(title)
        print(body % values)


def main():
    """
        Command line entry point.

        Depending on the options this either deletes saved snapshots,
        saves a fresh snapshot ("clear"), prints the counters relative to a
        saved snapshot, or prints the change over a waiting period.
        Snapshots are pickled into /tmp/intfstat-<uid>/.
    """
    # Function-scope import: errno is not imported at the top of the file.
    import errno

    # NOTE(review): the epilog mentions a port state legend and the options
    # -r / -a, none of which are defined by this parser -- confirm and trim.
    parser = argparse.ArgumentParser(description='Display the interfaces state and counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
        Port state: (U)-Up (D)-Down (X)-Disabled
        Examples:
        intfstat -c -t test
        intfstat -t test
        intfstat -d -t test
        intfstat
        intfstat -r
        intfstat -a
        intfstat -p 20
        intfstat -i Vlan1000
        """)

    # Same -v/--version option the deprecated ArgumentParser(version=...)
    # keyword used to add.
    parser.add_argument('-v', '--version', action='version', version='1.0.0',
                        help="show program's version number and exit")
    parser.add_argument('-c', '--clear', action='store_true', help='Copy & clear stats')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats, either the uid or the specified tag')
    parser.add_argument('-D', '--delete-all', action='store_true', help='Delete all saved stats')
    parser.add_argument('-j', '--json', action='store_true', help='Display in JSON format')
    parser.add_argument('-t', '--tag', type=str, help='Save stats with name TAG', default=None)
    parser.add_argument('-i', '--interface', type=str, help='Show stats for a single interface', required=False)
    parser.add_argument('-p', '--period', type=int, help='Display stats over a specified period (in seconds).', default=0)
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_saved_stats = args.delete
    delete_all_stats = args.delete_all
    use_json = args.json
    tag_name = args.tag if args.tag else ""
    uid = str(os.getuid())
    wait_time_in_seconds = args.period
    interface_name = args.interface if args.interface else ""

    # A negative period would only fail later inside time.sleep().
    if wait_time_in_seconds < 0:
        parser.error("period must not be negative")

    # fancy filename with dashes: uid-tag-intf / uid-intf / uid-tag etc
    filename_components = [uid, tag_name, interface_name]
    cnstat_file = "-".join(filter(None, filename_components))

    cnstat_dir = "/tmp/intfstat-" + uid
    cnstat_fqn_file = cnstat_dir + "/" + cnstat_file

    if delete_all_stats:
        # There is nothing to delete
        if not os.path.isdir(cnstat_dir):
            sys.exit(0)

        # os.listdir/os.remove/os.rmdir raise OSError (not IOError).
        try:
            for entry in os.listdir(cnstat_dir):
                os.remove(cnstat_dir + "/" + entry)
            os.rmdir(cnstat_dir)
        except OSError as e:
            print("%s %s" % (e.errno, e))
            sys.exit(1)
        sys.exit(0)

    if delete_saved_stats:
        try:
            os.remove(cnstat_fqn_file)
        except OSError as e:
            # A snapshot that does not exist is not an error.
            if e.errno != errno.ENOENT:
                print("%s %s" % (e.errno, e))
                sys.exit(1)

        # Drop the per-user directory once it is empty. This is done after
        # the try block (not in a finally) so a failure above keeps its
        # exit status 1.
        try:
            if not os.listdir(cnstat_dir):
                os.rmdir(cnstat_dir)
        except OSError as e:
            # The directory may never have been created.
            if e.errno != errno.ENOENT:
                print("%s %s" % (e.errno, e))
                sys.exit(1)
        sys.exit(0)

    intfstat = Intfstat()
    cnstat_dict = intfstat.get_cnstat(rif=interface_name)

    # At this point, either we'll create a file or open an existing one.
    if not os.path.exists(cnstat_dir):
        try:
            os.makedirs(cnstat_dir)
        except OSError as e:
            print("%s %s" % (e.errno, e))
            sys.exit(1)

    if save_fresh_stats:
        try:
            # Binary mode and a context manager: pickle data is bytes and
            # the file must be closed (flushed) before exiting.
            with open(cnstat_fqn_file, 'wb') as stats_file:
                pickle.dump(cnstat_dict, stats_file)
        except IOError as e:
            sys.exit(e.errno)
        else:
            print("Cleared counters")
            sys.exit(0)

    if wait_time_in_seconds == 0:
        if os.path.isfile(cnstat_fqn_file):
            # NOTE(review): the snapshot lives in a predictable /tmp path
            # and unpickling executes whatever the file describes. If
            # another user can pre-create /tmp/intfstat-<uid>, this loads
            # untrusted data -- consider checking directory ownership or
            # using a safer format.
            try:
                with open(cnstat_fqn_file, 'rb') as stats_file:
                    cnstat_cached_dict = pickle.load(stats_file)
            except IOError as e:
                print("%s %s" % (e.errno, e))
            else:
                print("Last cached time was " + str(cnstat_cached_dict.get('time')))
                if interface_name:
                    intfstat.cnstat_single_interface(interface_name, cnstat_dict, cnstat_cached_dict)
                else:
                    intfstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, use_json)
        else:
            if tag_name:
                print("\nFile '%s' does not exist" % cnstat_fqn_file)
                print("Did you run 'intfstat -c -t %s' to record the counters via tag %s?\n" % (tag_name, tag_name))
            else:
                if interface_name:
                    intfstat.cnstat_single_interface(interface_name, cnstat_dict, None)
                else:
                    intfstat.cnstat_print(cnstat_dict, use_json)
    else:
        #wait for the specified time and then gather the new stats and output the difference.
        time.sleep(wait_time_in_seconds)
        print("The rates are calculated within %s seconds period" % wait_time_in_seconds)
        cnstat_new_dict = intfstat.get_cnstat(rif=interface_name)
        if interface_name:
            intfstat.cnstat_single_interface(interface_name, cnstat_new_dict, cnstat_dict)
        else:
            intfstat.cnstat_diff_print(cnstat_new_dict, cnstat_dict, use_json)

if __name__ == "__main__":
    main()
